/**
 *
 *  @file CoroMapper.h
 *  @author An Tao
 *
 *  Copyright 2018, An Tao.  All rights reserved.
 *  https://github.com/an-tao/drogon
 *  Use of this source code is governed by a MIT license
 *  that can be found in the License file.
 *
 *  Drogon
 *
 */
#pragma once

#include <functional>
#include <tuple>

#ifdef __cpp_impl_coroutine
#include <drogon/orm/Mapper.h>
#include <drogon/utils/coroutine.h>

namespace drogon
{
namespace orm
{
namespace internal
{
template <typename ReturnType>
struct [[nodiscard]] MapperAwaiter : public CallbackAwaiter<ReturnType>
{
    /**
     * A deferred database operation. When invoked it must eventually call
     * exactly one of the two callbacks: the first with the result, the second
     * with the exception describing a failure.
     */
    using MapperFunction =
        std::function<void(std::function<void(ReturnType result)> &&,
                           std::function<void(const std::exception_ptr &)> &&)>;

    /**
     * @brief Wrap a deferred operation so that it can be co_awaited.
     *
     * @param function The operation; it is not started here, only stored until
     * the awaiting coroutine suspends.
     */
    explicit MapperAwaiter(MapperFunction &&function)
        : function_(std::move(function))
    {
    }

    /**
     * @brief Start the stored operation and arrange for @p handle to resume.
     *
     * On success the value is stored with setValue(); on failure the
     * exception is stored with setException(). In both cases the coroutine is
     * resumed directly from inside the callback, i.e. on whatever thread the
     * operation completes.
     *
     * @param handle The suspended coroutine waiting for the result.
     */
    void await_suspend(std::coroutine_handle<> handle)
    {
        auto onResult = [handle, this](ReturnType result) {
            this->setValue(std::move(result));
            handle.resume();
        };
        auto onError = [handle, this](const std::exception_ptr &e) {
            this->setException(e);
            handle.resume();
        };
        function_(std::move(onResult), std::move(onError));
    }

  private:
    // The deferred operation, run once from await_suspend().
    MapperFunction function_;
};
}  // namespace internal

/**
 * @brief This template implements coroutine interfaces of ORM. All the methods
 * of this template are coroutine versions of the synchronous interfaces of the
 * orm::Mapper template.
 *
 * Every query method returns an internal::MapperAwaiter. The SQL statement is
 * built and sent only when that awaiter is co_awaited, so the query conditions
 * (limit, offset, order, forUpdate) used are the ones set on this mapper at
 * that moment. The deferred operations capture `this`, so the mapper must stay
 * alive until the co_await completes.
 *
 * @tparam T The type of the model.
 */
template <typename T>
class CoroMapper : public Mapper<T>
{
  public:
    using SingleRowCallback = typename Mapper<T>::SingleRowCallback;
    using MultipleRowsCallback = typename Mapper<T>::MultipleRowsCallback;
    using CountCallback = typename Mapper<T>::CountCallback;
    using ExceptPtrCallback = std::function<void(const std::exception_ptr &)>;

    explicit CoroMapper(DbClientPtr client) : Mapper<T>(std::move(client))
    {
    }

    using TraitsPKType = typename Mapper<T>::TraitsPKType;

    /**
     * @brief Find the single record whose primary key equals @p key.
     *
     * @param key The primary key value.
     * @return An awaiter yielding the matching object. UnexpectedRows is
     * reported through the awaiter when zero or more than one row matches.
     * Calling this on a model without a primary key logs a fatal error and
     * aborts.
     */
    inline internal::MapperAwaiter<T> findByPrimaryKey(const TraitsPKType &key)
    {
        if constexpr (!std::is_same_v<typename T::PrimaryKeyType, void>)
        {
            auto lb = [this, key](SingleRowCallback &&callback,
                                  ExceptPtrCallback &&errCallback) mutable {
                static_assert(!std::is_same_v<typename T::PrimaryKeyType, void>,
                              "No primary key in the table!");
                static_assert(
                    internal::has_sqlForFindingByPrimaryKey<T>::value,
                    "No function member named sqlForFindingByPrimaryKey, "
                    "please "
                    "make sure that the model class is generated by the latest "
                    "version of drogon_ctl");
                std::string sql = T::sqlForFindingByPrimaryKey();
                if (this->forUpdate_)
                {
                    sql += " for update";
                }
                this->clear();
                auto binder = *(this->client_) << std::move(sql);
                this->outputPrimaryKeyToBinder(key, binder);

                binder >> [callback = std::move(callback),
                           errCallback](const Result &r) {
                    if (r.size() == 0)
                    {
                        errCallback(std::make_exception_ptr(
                            UnexpectedRows("0 rows found")));
                    }
                    else if (r.size() > 1)
                    {
                        errCallback(std::make_exception_ptr(
                            UnexpectedRows("Found more than one row")));
                    }
                    else
                    {
                        callback(T(r[0]));
                    }
                };
                binder >> std::move(errCallback);
                binder.exec();
            };
            return internal::MapperAwaiter<T>(std::move(lb));
        }
        else
        {
            LOG_FATAL << "The table must have a primary key";
            abort();
        }
    }

    // Query condition overrides

    /**
     * @brief Add a limit to the query.
     *
     * @param limit The limit
     * @return CoroMapper<T>& The CoroMapper itself.
     */
    CoroMapper<T> &limit(size_t limit)
    {
        Mapper<T>::limit(limit);
        return *this;
    }

    /**
     * @brief Add an offset to the query.
     *
     * @param offset The offset.
     * @return CoroMapper<T>& The CoroMapper itself.
     */
    CoroMapper<T> &offset(size_t offset)
    {
        Mapper<T>::offset(offset);
        return *this;
    }

    /**
     * @brief Set the order of the results.
     *
     * @param colName the column name, the results are sorted by that column
     * @param order Ascending or descending order
     * @return CoroMapper<T>& The CoroMapper itself.
     */
    CoroMapper<T> &orderBy(const std::string &colName,
                           const SortOrder &order = SortOrder::ASC)
    {
        Mapper<T>::orderBy(colName, order);
        return *this;
    }

    /**
     * @brief Set the order of the results.
     *
     * @param colIndex the column index, the results are sorted by that column
     * @param order Ascending or descending order
     * @return CoroMapper<T>& The CoroMapper itself.
     */
    CoroMapper<T> &orderBy(size_t colIndex,
                           const SortOrder &order = SortOrder::ASC)
    {
        Mapper<T>::orderBy(colIndex, order);
        return *this;
    }

    /**
     * @brief Set limit and offset to achieve pagination.
     * This method overrides limit() and offset(), and is overridden by them
     * when they are called afterwards.
     *
     * @param page The page number
     * @param perPage The number of rows per page
     * @return CoroMapper<T>& The CoroMapper itself.
     */
    CoroMapper<T> &paginate(size_t page, size_t perPage)
    {
        Mapper<T>::paginate(page, perPage);
        return *this;
    }

    /**
     * @brief Lock the result for updating.
     *
     * @return CoroMapper<T>& The CoroMapper itself.
     */
    CoroMapper<T> &forUpdate()
    {
        Mapper<T>::forUpdate();
        return *this;
    }

    // Read api for coroutines

    /**
     * @brief Find all records, honouring any limit/offset/order set on the
     * mapper. Equivalent to findBy(Criteria()).
     */
    inline internal::MapperAwaiter<std::vector<T>> findAll()
    {
        return findBy(Criteria());
    }

    /**
     * @brief Count the records matching @p criteria (all records if empty).
     *
     * Any limit/offset/order set on the mapper is not applied, but it is
     * cleared.
     *
     * @param criteria Optional filter.
     * @return An awaiter yielding the value of `count(*)`.
     */
    inline internal::MapperAwaiter<size_t> count(
        const Criteria &criteria = Criteria())
    {
        auto lb = [this, criteria](CountCallback &&callback,
                                   ExceptPtrCallback &&errCallback) {
            std::string sql = "select count(*) from ";
            sql += T::tableName;
            if (criteria)
            {
                sql += " where ";
                sql += criteria.criteriaString();
                sql = this->replaceSqlPlaceHolder(sql, "$?");
            }
            this->clear();
            auto binder = *(this->client_) << std::move(sql);
            if (criteria)
                criteria.outputArgs(binder);
            binder >> [callback = std::move(callback)](const Result &r) {
                assert(r.size() == 1);
                callback(r[0][static_cast<Row::SizeType>(0)].as<size_t>());
            };
            binder >> std::move(errCallback);
        };
        return internal::MapperAwaiter<size_t>(std::move(lb));
    }

    /**
     * @brief Find exactly one record matching @p criteria.
     *
     * @param criteria Filter; when empty the whole table is considered.
     * @return An awaiter yielding the object. UnexpectedRows is reported
     * through the awaiter when zero or more than one row is returned.
     */
    inline internal::MapperAwaiter<T> findOne(const Criteria &criteria)
    {
        auto lb = [this, criteria](SingleRowCallback &&callback,
                                   ExceptPtrCallback &&errCallback) {
            std::string sql = "select * from ";
            sql += T::tableName;
            bool hasParameters = false;
            if (criteria)
            {
                sql += " where ";
                sql += criteria.criteriaString();
                hasParameters = true;
            }
            sql.append(this->orderByString_);
            if (this->limit_ > 0)
            {
                hasParameters = true;
                sql.append(" limit $?");
            }
            if (this->offset_ > 0)
            {
                hasParameters = true;
                sql.append(" offset $?");
            }
            if (hasParameters)
                sql = this->replaceSqlPlaceHolder(sql, "$?");
            if (this->forUpdate_)
            {
                sql += " for update";
            }
            auto binder = *(this->client_) << std::move(sql);
            // Arguments must be bound in the same order as the placeholders.
            if (criteria)
                criteria.outputArgs(binder);
            if (this->limit_ > 0)
                binder << this->limit_;
            if (this->offset_ > 0)
                binder << this->offset_;
            this->clear();
            binder >>
                [errCallback, callback = std::move(callback)](const Result &r) {
                    if (r.size() == 0)
                    {
                        errCallback(std::make_exception_ptr(
                            UnexpectedRows("0 rows found")));
                    }
                    else if (r.size() > 1)
                    {
                        errCallback(std::make_exception_ptr(
                            UnexpectedRows("Found more than one row")));
                    }
                    else
                    {
                        callback(T(r[0]));
                    }
                };
            binder >> std::move(errCallback);
        };
        return internal::MapperAwaiter<T>(std::move(lb));
    }

    /**
     * @brief Find all records matching @p criteria, honouring any
     * limit/offset/order/forUpdate set on the mapper.
     *
     * @param criteria Filter; when empty every row is selected.
     * @return An awaiter yielding the (possibly empty) list of objects.
     */
    inline internal::MapperAwaiter<std::vector<T>> findBy(
        const Criteria &criteria)
    {
        auto lb = [this, criteria](MultipleRowsCallback &&callback,
                                   ExceptPtrCallback &&errCallback) {
            std::string sql = "select * from ";
            sql += T::tableName;
            bool hasParameters = false;
            if (criteria)
            {
                hasParameters = true;
                sql += " where ";
                sql += criteria.criteriaString();
            }
            sql.append(this->orderByString_);
            if (this->limit_ > 0)
            {
                hasParameters = true;
                sql.append(" limit $?");
            }
            if (this->offset_ > 0)
            {
                hasParameters = true;
                sql.append(" offset $?");
            }
            if (hasParameters)
                sql = this->replaceSqlPlaceHolder(sql, "$?");
            if (this->forUpdate_)
            {
                sql += " for update";
            }
            auto binder = *(this->client_) << std::move(sql);
            // Arguments must be bound in the same order as the placeholders.
            if (criteria)
                criteria.outputArgs(binder);
            if (this->limit_ > 0)
                binder << this->limit_;
            if (this->offset_ > 0)
                binder << this->offset_;
            this->clear();
            binder >> [callback = std::move(callback)](const Result &r) {
                std::vector<T> ret;
                ret.reserve(r.size());
                for (auto const &row : r)
                {
                    ret.emplace_back(row);
                }
                // Hand the rows over without copying the whole vector.
                callback(std::move(ret));
            };
            binder >> std::move(errCallback);
        };
        return internal::MapperAwaiter<std::vector<T>>(std::move(lb));
    }

    /**
     * @brief Insert @p obj into its table.
     *
     * On PostgreSQL the inserted row is taken from the result when the model
     * asks for re-selection. On other clients the generated id is written with
     * updateId(), and the row is then re-read by primary key if re-selection
     * is needed.
     *
     * @param obj The object to insert.
     * @return An awaiter yielding the inserted object.
     */
    inline internal::MapperAwaiter<T> insert(const T &obj)
    {
        auto lb = [this, obj](SingleRowCallback &&callback,
                              ExceptPtrCallback &&errCallback) {
            this->clear();
            bool needSelection = false;
            auto binder = *(this->client_)
                          << obj.sqlForInserting(needSelection);
            obj.outputArgs(binder);
            auto client = this->client_;
            binder >> [client,
                       callback = std::move(callback),
                       obj,
                       needSelection,
                       errCallback](const Result &r) {
                assert(r.affectedRows() == 1);
                if (client->type() == ClientType::PostgreSQL)
                {
                    if (needSelection)
                    {
                        assert(r.size() == 1);
                        callback(T(r[0]));
                    }
                    else
                    {
                        callback(obj);
                    }
                }
                else  // Mysql or Sqlite3
                {
                    auto id = r.insertId();
                    auto newObj = obj;
                    newObj.updateId(id);
                    if (needSelection)
                    {
                        // Re-read the stored row. The synchronous-API mapper
                        // reports errors as DrogonDbException, which is
                        // re-wrapped as Failure for the awaiter.
                        auto tmp = Mapper<T>(client);
                        tmp.findByPrimaryKey(
                            newObj.getPrimaryKey(),
                            callback,
                            [errCallback](const DrogonDbException &err) {
                                errCallback(std::make_exception_ptr(
                                    Failure(err.base().what())));
                            });
                    }
                    else
                    {
                        callback(newObj);
                    }
                }
            };
            binder >> std::move(errCallback);
        };
        return internal::MapperAwaiter<T>(std::move(lb));
    }

    /**
     * @brief Write the columns listed by obj.updateColumns() back to the row
     * identified by obj's primary key.
     *
     * NOTE(review): if updateColumns() is empty the generated statement is
     * "update <table> set  where ...", which the database will reject.
     *
     * @param obj The object to persist.
     * @return An awaiter yielding the number of affected rows.
     */
    inline internal::MapperAwaiter<size_t> update(const T &obj)
    {
        auto lb = [this, obj](CountCallback &&callback,
                              ExceptPtrCallback &&errCallback) {
            this->clear();
            static_assert(!std::is_same_v<typename T::PrimaryKeyType, void>,
                          "No primary key in the table!");
            std::string sql = "update ";
            sql += T::tableName;
            sql += " set ";
            for (auto const &colName : obj.updateColumns())
            {
                sql += colName;
                sql += " = $?,";
            }
            sql[sql.length() - 1] = ' ';  // Replace the last ','

            this->makePrimaryKeyCriteria(sql);

            sql = this->replaceSqlPlaceHolder(sql, "$?");
            auto binder = *(this->client_) << std::move(sql);
            obj.updateArgs(binder);
            this->outputPrimaryKeyToBinder(obj.getPrimaryKey(), binder);
            binder >> [callback = std::move(callback)](const Result &r) {
                callback(r.affectedRows());
            };
            binder >> std::move(errCallback);
        };
        return internal::MapperAwaiter<size_t>(std::move(lb));
    }

    /**
     * @brief Set the columns named in @p colNames to @p args for all rows
     * matching @p criteria.
     *
     * @param colNames Column names; the count must equal the number of args
     * (checked at compile time).
     * @param criteria Filter; when empty every row is updated.
     * @param args New values, in the same order as @p colNames.
     * @return An awaiter yielding the number of affected rows.
     */
    template <typename... TupleArgs, typename... Arguments>
    inline internal::MapperAwaiter<size_t> updateBy(
        const std::tuple<TupleArgs...> &colNames,
        const Criteria &criteria,
        Arguments &&...args)
    {
        static_assert(sizeof...(args) > 0);
        static_assert(sizeof...(args) ==
                      std::tuple_size_v<std::tuple<TupleArgs...>>);
        std::string sql = "update ";
        sql += T::tableName;
        sql += " set ";
        std::apply(
            [&sql](auto &&...name) {
                ((sql += std::string(name) + " = $?,"), ...);
            },
            colNames);
        sql[sql.length() - 1] = ' ';  // Replace the last ','

        return updateByHelper(std::move(sql),
                              criteria,
                              std::forward<Arguments>(args)...);
    }

    /**
     * @brief Set the columns named in @p colNames to @p args for all rows
     * matching @p criteria.
     *
     * @param colNames Column names; the count must equal the number of args
     * (asserted in debug builds).
     * @param criteria Filter; when empty every row is updated.
     * @param args New values, in the same order as @p colNames.
     * @return An awaiter yielding the number of affected rows.
     */
    template <typename... Arguments>
    internal::MapperAwaiter<size_t> updateBy(
        const std::vector<std::string> &colNames,
        const Criteria &criteria,
        Arguments &&...args)
    {
        static_assert(sizeof...(args) > 0);
        assert(colNames.size() == sizeof...(args));
        std::string sql = "update ";
        sql += T::tableName;
        sql += " set ";
        for (auto const &colName : colNames)
        {
            sql += colName;
            sql += " = $?,";
        }
        sql[sql.length() - 1] = ' ';  // Replace the last ','

        return updateByHelper(std::move(sql),
                              criteria,
                              std::forward<Arguments>(args)...);
    }

    /**
     * @brief Add each delta in @p args to the matching column in @p colNames
     * ("col = col + $?" or "col = col - $?") for all rows matching @p criteria.
     *
     * @param colNames Column names; the count must equal the number of args
     * (asserted in debug builds).
     * @param criteria Filter; when empty every row is updated.
     * @param args Signed deltas, in the same order as @p colNames.
     * @return An awaiter yielding the number of affected rows.
     */
    template <typename... Arguments>
    inline internal::MapperAwaiter<size_t> increment(
        const std::vector<std::string> &colNames,
        const Criteria &criteria,
        Arguments... args)
    {
        static_assert(sizeof...(args) > 0);
        assert(colNames.size() == sizeof...(args));
        std::string sql = "update ";
        sql += T::tableName;
        sql += " set ";

        // Each delta is bound as a non-negative value; its sign selects the
        // SQL operator instead (e.g. -3 becomes "col = col - $?" bound to 3).
        std::vector<const char *> operators;
        operators.reserve(sizeof...(args));
        auto normalize = [&operators](auto &delta) {
            if (delta < 0)
            {
                operators.push_back(" - $?,");
                delta = -delta;
            }
            else
            {
                operators.push_back(" + $?,");
            }
        };
        (normalize(args), ...);  // left-to-right, one operator per argument

        for (size_t i = 0; i < sizeof...(args); ++i)
        {
            const auto &colName = colNames[i];
            sql += colName;
            sql += " = ";
            sql += colName;
            sql += operators[i];
        }
        sql[sql.length() - 1] = ' ';  // Replace the last ','

        return updateByHelper(std::move(sql),
                              criteria,
                              std::forward<Arguments>(args)...);
    }

  private:
    /**
     * @brief Finish an "update ... set" statement with the optional where
     * clause and run it, binding @p args before the criteria arguments.
     */
    template <typename... Arguments>
    internal::MapperAwaiter<size_t> updateByHelper(std::string &&sql,
                                                   const Criteria &criteria,
                                                   Arguments &&...args)
    {
        auto lb = [this,
                   sql = std::move(sql),
                   criteria,
                   ... args = std::forward<Arguments>(
                       args)](CountCallback &&callback,
                              ExceptPtrCallback &&errCallback) mutable {
            this->clear();

            if (criteria)
            {
                sql += " where ";
                sql += criteria.criteriaString();
            }

            sql = this->replaceSqlPlaceHolder(sql, "$?");
            auto binder = *(this->client_) << std::move(sql);
            // The "set" values precede the where-clause placeholders.
            (void)std::initializer_list<int>{(binder << args, 0)...};
            if (criteria)
                criteria.outputArgs(binder);
            binder >> [callback = std::move(callback)](const Result &r) {
                callback(r.affectedRows());
            };
            binder >> std::move(errCallback);
        };
        return internal::MapperAwaiter<size_t>(std::move(lb));
    }

  public:
    /**
     * @brief Delete the row identified by obj's primary key.
     *
     * @param obj The object whose row is deleted.
     * @return An awaiter yielding the number of affected rows.
     */
    inline internal::MapperAwaiter<size_t> deleteOne(const T &obj)
    {
        auto lb = [this, obj](CountCallback &&callback,
                              ExceptPtrCallback &&errCallback) {
            this->clear();
            static_assert(!std::is_same_v<typename T::PrimaryKeyType, void>,
                          "No primary key in the table!");
            std::string sql = "delete from ";
            sql += T::tableName;
            sql += " ";

            this->makePrimaryKeyCriteria(sql);

            sql = this->replaceSqlPlaceHolder(sql, "$?");
            auto binder = *(this->client_) << std::move(sql);
            this->outputPrimaryKeyToBinder(obj.getPrimaryKey(), binder);
            binder >> [callback = std::move(callback)](const Result &r) {
                callback(r.affectedRows());
            };
            binder >> std::move(errCallback);
        };
        return internal::MapperAwaiter<size_t>(std::move(lb));
    }

    /**
     * @brief Delete every row matching @p criteria.
     *
     * The statement never references the primary key, so unlike deleteOne()
     * this also works for tables without one. With an empty criteria the
     * statement is "delete from <table>", which removes all rows.
     *
     * @param criteria Filter selecting the rows to delete.
     * @return An awaiter yielding the number of affected rows.
     */
    inline internal::MapperAwaiter<size_t> deleteBy(const Criteria &criteria)
    {
        auto lb = [this, criteria](CountCallback &&callback,
                                   ExceptPtrCallback &&errCallback) {
            this->clear();
            std::string sql = "delete from ";
            sql += T::tableName;

            if (criteria)
            {
                sql += " where ";
                sql += criteria.criteriaString();
                sql = this->replaceSqlPlaceHolder(sql, "$?");
            }

            auto binder = *(this->client_) << std::move(sql);
            if (criteria)
            {
                criteria.outputArgs(binder);
            }
            binder >> [callback = std::move(callback)](const Result &r) {
                callback(r.affectedRows());
            };
            binder >> std::move(errCallback);
        };
        return internal::MapperAwaiter<size_t>(std::move(lb));
    }

    /**
     * @brief Delete the row whose primary key equals @p key.
     *
     * @param key The primary key value.
     * @return An awaiter yielding the number of affected rows.
     */
    inline internal::MapperAwaiter<size_t> deleteByPrimaryKey(
        const TraitsPKType &key)
    {
        static_assert(!std::is_same_v<typename T::PrimaryKeyType, void>,
                      "No primary key in the table!");
        static_assert(
            internal::has_sqlForDeletingByPrimaryKey<T>::value,
            "No function member named sqlForDeletingByPrimaryKey, please "
            "make sure that the model class is generated by the latest "
            "version of drogon_ctl");
        auto lb = [this, key](CountCallback &&callback,
                              ExceptPtrCallback &&errCallback) {
            this->clear();
            auto binder = *(this->client_) << T::sqlForDeletingByPrimaryKey();
            this->outputPrimaryKeyToBinder(key, binder);
            binder >> [callback = std::move(callback)](const Result &r) {
                callback(r.affectedRows());
            };
            binder >> std::move(errCallback);
        };
        return internal::MapperAwaiter<size_t>(std::move(lb));
    }
};
}  // namespace orm
}  // namespace drogon
#endif
